package org.misty.bloomfilter;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

/**
 * A Bloom filter whose bit array is stored in a single Redis key and manipulated through Lua
 * scripts (SETBIT / GETBIT) executed by an {@link IRedisClient}.
 *
 * <p>The number of bits and hash functions are derived from the configured expected number of
 * insertions and the desired false-positive probability, using the same formulas as Guava 29.0.
 *
 * <p>The bit offsets produced by {@link #getIndexes} address bits that may already be persisted in
 * Redis, so the hashing scheme ({@link #hash}, {@link #murmur3_32}, {@link #HASH_INCREMENT}) must
 * not change without migrating existing filters.
 */
public class BloomFilter {
    private final static Logger log = LoggerFactory.getLogger(BloomFilter.class);

    /**
     * Step added to the base hash to derive each successive bit offset (the same constant that
     * {@code java.lang.ThreadLocal} uses for its hash increment).
     */
    private final static int HASH_INCREMENT = 0x61c88647;

    /** Sets every offset passed in ARGV to 1 in the bitmap at KEYS[1] and replies 'OK'. */
    private final static String PUT_SCRIPT = "local k_bloomFilter = KEYS[1]\n" +
        "local v_indexes = ARGV\n" +
        "for i, v in ipairs(v_indexes) do\n" +
        "  redis.call('setbit', k_bloomFilter, v, 1)\n" +
        "end\n" +
        "return 'OK'";

    /**
     * Sums the bits at every offset passed in ARGV within the bitmap at KEYS[1]. Replies Lua
     * {@code true} (which Redis converts to integer 1) when all of them are set, otherwise 0.
     */
    private final static String EXISTS_SCRIPT = "local k_bloomFilter = KEYS[1]\n" +
        "local v_indexes = ARGV\n" +
        "local n = table.getn(v_indexes)\n" +
        "local s = 0\n" +
        "for i, v in ipairs(v_indexes) do\n" +
        "  s = s + redis.call('getbit', k_bloomFilter, v)\n" +
        "end\n" +
        "return s == n or 0";

    /** Client used to evaluate the Lua scripts. */
    private final IRedisClient<?> client;
    /** Size of the bit array; always >= 1. */
    private final long numBits;
    /** Number of bit offsets derived per key; always >= 1. */
    private final int numHashFunctions;
    /** Redis key holding the bitmap. */
    private final String redisKey;

    /**
     * Creates a filter sized for {@code config.getExpectedInsertions()} elements at a false-positive
     * probability of {@code config.getFpp()}.
     *
     * @param redisClient client used to run the Lua scripts; must not be null
     * @param config sizing parameters and the Redis key of the bitmap; must not be null, and its
     *     redis key must not be null
     * @throws NullPointerException if {@code redisClient}, {@code config} or the redis key is null
     * @throws IllegalArgumentException if the expected insertions are negative, or fpp is not
     *     strictly between 0.0 and 1.0
     */
    public BloomFilter(IRedisClient<?> redisClient, IConfig config) {
        Objects.requireNonNull(redisClient, "redisClient");
        Objects.requireNonNull(config, "config");

        this.client = redisClient;
        long expectedInsertions = config.getExpectedInsertions();
        double fpp = config.getFpp();

        checkArgument(expectedInsertions >= 0, "Expected insertions expectedInsertions(%s) must be >= 0", expectedInsertions);
        checkArgument(fpp > 0.0, "False positive probability fpp(%s) must be > 0.0", fpp);
        checkArgument(fpp < 1.0, "False positive probability fpp(%s) must be < 1.0", fpp);

        // As Guava does: size an "empty" filter as if it held one element. Otherwise numBits would be 0
        // (every index computation would divide by zero) and the hash-function count would be 0/0.
        if (expectedInsertions == 0) {
            expectedInsertions = 1;
        }

        // Tiny filters with an fpp close to 1.0 can still round down to 0 bits; keep at least one.
        this.numBits = Math.max(1L, optimalNumOfBits(expectedInsertions, fpp));
        this.numHashFunctions = optimalNumOfHashFunctions(expectedInsertions, numBits);

        this.redisKey = Objects.requireNonNull(config.getRedisKey(), "redisKey");
    }

    /**
     * Adds {@code key} to the filter by setting all of its bits in Redis.
     *
     * @param key element to add; must not be null
     * @return {@code true} if the script replied {@code "OK"}, {@code false} for any other reply
     * @throws NullPointerException if {@code key} is null
     */
    public boolean put(Object key) {
        String[] indexes = getIndexes(key, numBits, numHashFunctions);
        Object ret = client.eval(PUT_SCRIPT, String.class, List.of(redisKey), indexes);
        // String.equals is false for non-String replies, so no separate type check is needed.
        return "OK".equals(ret);
    }

    /**
     * Tests whether {@code key} may have been added to the filter.
     *
     * @param key element to test; must not be null
     * @return {@code true} if every bit for {@code key} is set, which may be a false positive;
     *     {@code false} if any bit is unset, or if the reply is not the {@code Long} value 1
     * @throws NullPointerException if {@code key} is null
     */
    public boolean exists(Object key) {
        String[] indexes = getIndexes(key, numBits, numHashFunctions);
        Object ret = client.eval(EXISTS_SCRIPT, Long.class, List.of(redisKey), indexes);
        // Long.equals is true only for a Long with the same value.
        return Long.valueOf(1L).equals(ret);
    }

    /**
     * Computes the bit offsets for {@code k}, as decimal strings ready to be passed as script ARGV.
     *
     * <p>Offset {@code i} is {@code |hash(k) + i * HASH_INCREMENT| % numBits}, computed in
     * {@code long} arithmetic.
     *
     * @param k key to hash; must not be null
     * @param numBits size of the bit array; must be > 0
     * @param numHashFunctions number of offsets to produce
     * @return array of {@code numHashFunctions} offsets in {@code [0, numBits)}
     * @throws IllegalArgumentException if {@code numBits <= 0}
     * @throws NullPointerException if {@code k} is null
     */
    public static String[] getIndexes(Object k, long numBits, int numHashFunctions) {
        checkArgument(numBits > 0, "Number of bits numBits(%s) must be > 0", numBits);
        String[] idx = new String[numHashFunctions];
        long h = hash(k);
        for (int i = 0; i < numHashFunctions; i++) {
            // Math.abs (rather than floorMod) is part of the persisted offset scheme; do not change it.
            idx[i] = String.valueOf(Math.abs(h) % numBits);
            h += HASH_INCREMENT;
        }
        return idx;
    }

    /**
     * Computes the 32-bit base hash of {@code k}, widened to {@code long}.
     *
     * <p>Strings are hashed with {@link #murmur3_32} (seed 0). Any other object uses its
     * {@code hashCode()}, mixed with the supplemental bit-spreading of the pre-Java-8
     * {@code java.util.HashMap.hash}.
     *
     * @param k key to hash; must not be null
     * @return the sign-extended 32-bit hash
     * @throws NullPointerException if {@code k} is null
     */
    public static long hash(Object k) {
        Objects.requireNonNull(k, "key");
        if (k instanceof String) {
            return murmur3_32(0, (String) k);
        }
        int h = k.hashCode();
        h ^= (h >>> 20) ^ (h >>> 12);
        h ^= (h >>> 7) ^ (h >>> 4);
        return h;
    }

    /**
     * Hashes the UTF-8 bytes of {@code value} with a MurmurHash3-style 32-bit function.
     *
     * <p>NOTE(review): this is not the canonical byte-oriented MurmurHash3_x86_32. It consumes two
     * sign-extended bytes per round and mixes in {@code len * 2} as the length, which looks like
     * Guava's {@code hashUnencodedChars} loop applied to a byte array. Its output addresses bits that
     * may already be stored in Redis, so do not "correct" it without migrating existing filters.
     *
     * <p>The string is encoded explicitly as UTF-8, so every JVM derives the same offsets for the same
     * key regardless of its default charset. JDK 18+ already defaults to UTF-8.
     *
     * @param seed initial hash state
     * @param value string to hash; must not be null
     * @return 32-bit hash
     */
    public static int murmur3_32(int seed, String value) {
        byte[] data = value.getBytes(StandardCharsets.UTF_8);
        int len = data.length;
        int res = seed;
        int pos = 0;
        int temp;
        int i;
        // Body: two bytes per round; the loop update applies the Murmur3 "h = h * 5 + n" step.
        for (i = len; i >= 2; res = res * 5 + -430675100) {
            temp = data[pos++] & '\uffff' | data[pos++] << 16;
            i -= 2;
            temp *= -862048943;
            temp = Integer.rotateLeft(temp, 15);
            temp *= 461845907;
            res ^= temp;
            res = Integer.rotateLeft(res, 13);
        }

        // Tail: at most one leftover byte.
        if (i > 0) {
            temp = data[pos] * -862048943;
            temp = Integer.rotateLeft(temp, 15);
            temp *= 461845907;
            res ^= temp;
        }

        // Finalization (fmix32) over the state xor'ed with the length term.
        res ^= len * 2;
        res ^= res >>> 16;
        res *= -2048144789;
        res ^= res >>> 13;
        res *= -1028477387;
        res ^= res >>> 16;
        return res;
    }

    /*
     * Guava 29.0: m = -n * ln(p) / (ln 2)^2, the optimal bit count for n insertions at fpp p.
     */
    private static long optimalNumOfBits(long n, double p) {
        if (p == 0) {
            p = Double.MIN_VALUE;
        }
        return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
    }

    /*
     * Guava 29.0: k = (m / n) * ln 2, the optimal hash-function count (at least 1).
     */
    private static int optimalNumOfHashFunctions(long n, long m) {
        // (m / n) * log(2), but avoid truncation due to division!
        return Math.max(1, (int) Math.round((double) m / n * Math.log(2)));
    }

    /**
     * Throws {@link IllegalArgumentException} with {@code String.format(errorMessageTemplate, p1)}
     * when {@code b} is false.
     */
    private static void checkArgument(boolean b, String errorMessageTemplate, @Nullable Object p1) {
        if (!b) {
            throw new IllegalArgumentException(String.format(errorMessageTemplate != null ? errorMessageTemplate : "%s", p1));
        }
    }

    /** Sizing and location parameters for a {@link BloomFilter}. */
    public interface IConfig {
        /** Expected number of elements to insert; must be >= 0. */
        long getExpectedInsertions();

        /** Desired false-positive probability; must be strictly between 0.0 and 1.0. */
        double getFpp();

        /** Redis key under which the bitmap is stored; must not be null. */
        String getRedisKey();
    }

    /**
     * Minimal Redis scripting abstraction used by {@link BloomFilter}.
     *
     * @param <S> result type of {@link #scriptLoad}
     */
    public interface IRedisClient<S> {
        /**
         * Loads {@code script} into Redis. NOTE(review): presumably returns the script's SHA1 handle;
         * not used by this class.
         */
        S scriptLoad(String script, Class<?> resultType);

        /**
         * Evaluates the Lua {@code script} with {@code keys} as KEYS and {@code args} as ARGV.
         * NOTE(review): {@code resultType} appears to be a hint for decoding the reply; the exact
         * mapping depends on the implementation.
         */
        Object eval(String script, Class<?> resultType, List<String> keys, String... args);
    }
}
